Support soft target in softmax_cross_entropy #5595

Merged
merged 14 commits into from Oct 31, 2019

Conversation

@anaruse (Contributor) commented Oct 29, 2018

This PR aims to support "soft target" in softmax_cross_entropy.

The current softmax_cross_entropy implementation supports "hard targets" (integer class labels) but not "soft targets" (per-class probability distributions), which are becoming popular as a way to mitigate over-fitting. This PR allows users to pass a soft target to softmax_cross_entropy as follows.

soft_target_loss = F.softmax_cross_entropy(x, t, soft_target=soft_target)

The soft target loss is the KL divergence between the target distribution and the softmax output.
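
For reference, here is a minimal NumPy sketch of the soft target loss described above, i.e. the KL divergence between the target distribution t and softmax(x). The function name, the eps constant, and the batch-mean reduction are illustrative choices for this sketch, not the PR's internals.

import numpy as np

def soft_target_kl_loss(x, t, eps=1e-8):
    # Log-softmax of the logits along the class axis, computed stably.
    x_shifted = x - x.max(axis=1, keepdims=True)
    log_y = x_shifted - np.log(np.exp(x_shifted).sum(axis=1, keepdims=True))
    # Per-sample KL(t || softmax(x)); eps guards log(0) for zero probabilities.
    kl = (t * (np.log(t + eps) - log_y)).sum(axis=1)
    return kl.mean()

x = np.random.randn(4, 3).astype(np.float32)     # logits: 4 samples, 3 classes
t = np.full((4, 3), 1.0 / 3, dtype=np.float32)   # soft targets (rows sum to 1)
print(soft_target_kl_loss(x, t))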

@anaruse (Contributor Author) commented Oct 29, 2018

Or would it be better to implement this as a separate function, such as softmax_kl_divergence?

@anaruse (Contributor Author) commented Oct 30, 2018

I've updated the PR so that the argument t is used for both hard and soft targets. Whether t is a hard or soft target is determined from the ndim and shape of x and t; a sketch of the idea follows.
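
For illustration, a minimal sketch of a dispatch rule along these lines; the helper name is made up, and in the PR itself the detection happens inside the function rather than in a standalone helper.

import numpy as np

def is_soft_target(x, t):
    # Soft target: t matches x in both ndim and shape, i.e. one probability
    # distribution over classes per sample. Hard target otherwise
    # (integer class indices with one fewer dimension).
    return t.ndim == x.ndim and t.shape == x.shape

x = np.random.randn(4, 3).astype(np.float32)          # logits: 4 samples, 3 classes
t_hard = np.array([0, 2, 1, 1], dtype=np.int32)       # class indices
t_soft = np.full((4, 3), 1.0 / 3, dtype=np.float32)   # one distribution per sample

assert not is_soft_target(x, t_hard)
assert is_soft_target(x, t_soft)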

@anaruse changed the title from "[WIP] Support soft target in softmax_cross_entropy" to "Support soft target in softmax_cross_entropy" on Nov 1, 2018
@beam2d self-assigned this on Nov 13, 2018
@beam2d (Member) left a comment

The design and implementation look good. I added some minor comments.

t_type.ndim == x_type.ndim - 1,
if x_type.ndim == t_type.ndim and x_type.shape == t_type.shape:
# assume t is soft_target
self.soft_target = True
Member

Keep check_type_forward free of side effects. This method may be skipped entirely when CHAINER_TYPE_CHECK=0 is set.

x_type.dtype.kind == 'f',
t_type.dtype.kind == 'i',
t_type.ndim == x_type.ndim - 1,
if x_type.ndim == t_type.ndim and x_type.shape == t_type.shape:
Member

I feel it's better to branch based on the dtype kind and then check ndim/shape with expect. That will produce an error message that matches the user's intent (see the sketch below).
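
To illustrate the point in plain Python (this is deliberately not Chainer's type_check API, just a sketch of the intended behavior with a made-up helper name): branching on the dtype kind of t first, and only then validating ndim and shape, yields an error that matches what the caller meant.

import numpy as np

def check_inputs(x, t):
    # x: float logits of shape (N, C, d_1, ..., d_k).
    assert x.dtype.kind == 'f', 'x must be a float array of logits'
    if t.dtype.kind == 'i':
        # The caller intends a hard target: integer class indices.
        assert t.ndim == x.ndim - 1, 'hard target must have ndim == x.ndim - 1'
        assert t.shape == x.shape[:1] + x.shape[2:], 'hard target shape mismatch'
    else:
        # The caller intends a soft target: a distribution with the shape of x.
        assert t.dtype.kind == 'f', 'soft target must be a float array'
        assert t.shape == x.shape, 'soft target must have the same shape as x'

x = np.random.randn(4, 3).astype(np.float32)
check_inputs(x, np.array([0, 2, 1, 1], dtype=np.int32))      # hard target: OK
check_inputs(x, np.full((4, 3), 1.0 / 3, dtype=np.float32))  # soft target: OK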

def _soft_target_loss(self, xp, x, t, log_y):
kl_d = xp.sum(t * (xp.log(t + self.eps) - log_y), axis=1)
if self.reduce == 'mean':
self._coeff = 1.0 / (numpy.prod(x.shape) / x.shape[1])
Member

Suggested change
self._coeff = 1.0 / (numpy.prod(x.shape) / x.shape[1])
self._coeff = 1.0 / (x.size / x.shape[1])

size can be used to get the total number of elements.

return kl_d.reshape(()),
else:
shape = (x.shape[0],) + x.shape[2:]
return kl_d.reshape(shape),
Member

Why is this reshape needed?

self.check_backward_options = {}

def check_forward(self, xp):
pass
Member

Suggested change
pass
raise NotImplementedError

to ensure this method is overridden.

t_hard_shape = (self.nb,) + self.shape[1:]
self.t_hard = numpy.random.randint(
0, self.shape[0], t_hard_shape).astype(numpy.int32)
t = numpy.zeros(numpy.prod(self.x.shape)).astype(self.dtype)
Member

Suggested change
t = numpy.zeros(numpy.prod(self.x.shape)).astype(self.dtype)
t = numpy.zeros(self.x.size).astype(self.dtype)

@anaruse (Contributor Author) commented Nov 20, 2018

Thanks for your comments. I've updated the branch based on your feedback. Please review it again.

@beam2d (Member) commented Nov 27, 2018

Thank you for the updates. Looks good to me. Could you add a description of the soft target support to the docstring?

@anaruse (Contributor Author) commented Nov 27, 2018

I've just added a description of the soft target support. Please check it again.

@toslunar (Member)

Could you resolve merge conflicts?

@anaruse (Contributor Author) commented Jan 15, 2019

Sorry for the late reply. I've resolved the conflicts with the master branch. Could you check it again?

stale bot commented Apr 15, 2019

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs. Thank you for your contributions.

The stale bot added the "stale" (Not updated for a longer period of time.) label on Apr 15, 2019
@beam2d (Member) commented Apr 16, 2019

Bump. Sorry for the late response. Could you resolve the conflict again?

The stale bot removed the "stale" (Not updated for a longer period of time.) label on Apr 16, 2019
@anaruse (Contributor Author) commented Apr 17, 2019

I've just resolved the conflicts with the master branch. Could you check it?

@beam2d (Member) commented Jul 3, 2019

Sorry for the late reply. The code looks good to me, but I now have a concern about the naming. A function named "cross entropy" that actually computes a KL divergence is confusing and quite likely to be misused. You asked above whether it would be better to call it softmax_kl_divergence; that name sounds better to me, but given the current implementation I also think adding a new function is overkill. How about adding an option to turn the negative entropy term on and off (i.e., to switch between a cross-entropy mode and a KL-divergence mode)? The option only affects the soft target case, because cross entropy and KL divergence coincide for hard targets.
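
A small NumPy sketch of the relationship behind this suggestion (variable names are illustrative): the soft-target cross entropy decomposes as H(t, y) = H(t) + KL(t || y), so the proposed option amounts to toggling the entropy term of t.

import numpy as np

eps = 1e-8
x = np.random.randn(4, 3).astype(np.float32)                     # logits
t = np.random.dirichlet(np.ones(3), size=4).astype(np.float32)   # soft targets

x_shifted = x - x.max(axis=1, keepdims=True)
log_y = x_shifted - np.log(np.exp(x_shifted).sum(axis=1, keepdims=True))

cross_entropy = -(t * log_y).sum(axis=1)                      # H(t, y)
entropy_t = -(t * np.log(t + eps)).sum(axis=1)                # H(t)
kl_divergence = (t * (np.log(t + eps) - log_y)).sum(axis=1)   # KL(t || y)

# H(t, y) == H(t) + KL(t || y), up to the eps used for log(0) safety.
assert np.allclose(cross_entropy, entropy_t + kl_divergence, atol=1e-5)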

@@ -223,21 +244,32 @@ def forward_gpu(self, inputs):
ret = ret.reshape(t.shape)
return ret,

def _soft_target_loss(self, xp, x, t, log_y):
kl_d = xp.sum(t * (xp.log(t + self.eps) - log_y), axis=1)
Member

To compute the cross entropy,

Suggested change
kl_d = xp.sum(t * (xp.log(t + self.eps) - log_y), axis=1)
____ = -xp.sum(t * log_y, axis=1)

@anaruse (Contributor Author) commented Jul 16, 2019

Now I'm wondering whether this is correct, since the test below fails with it.

class TestSoftTargetExpectNearZero(BaseSoftTarget, unittest.TestCase):

This test uses the output of softmax as the soft target, so the output of softmax_cross_entropy is expected to be zero or nearly zero. However, softmax_cross_entropy returns a non-zero value when 'cross-entropy' is used to compute the soft target loss.
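
A quick NumPy check of this observation (illustrative only): when the soft target is exactly softmax(x), the KL divergence is about zero, while the cross entropy equals the entropy of the target, which is generally non-zero.

import numpy as np

x = np.random.randn(4, 3).astype(np.float32)
x_shifted = x - x.max(axis=1, keepdims=True)
log_y = x_shifted - np.log(np.exp(x_shifted).sum(axis=1, keepdims=True))
t = np.exp(log_y)                                    # soft target = softmax(x)

kl = (t * (np.log(t + 1e-8) - log_y)).sum(axis=1)    # ~0 for every sample
ce = -(t * log_y).sum(axis=1)                        # equals H(t), non-zero

print(kl)  # close to zero
print(ce)  # non-zero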

Member

Done in 1ab2632.

@toslunar (Member) commented Jul 4, 2019

BTW, chainer.distributions.Categorical is available.

@anaruse (Contributor Author) commented Jul 9, 2019

All right, I added a soft_target_loss option so that you can choose which method is used to compute the soft target loss: 'cross-entropy' or 'kl-divergence'. What do you think about this option?
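
With this option, a call could look like the following usage sketch, assuming a Chainer build that includes this change; the soft target t has the same shape as x, and the option values are the ones named above.

import numpy as np
import chainer.functions as F

x = np.random.randn(4, 3).astype(np.float32)
t = np.full((4, 3), 1.0 / 3, dtype=np.float32)   # soft target, same shape as x

loss_kl = F.softmax_cross_entropy(x, t, soft_target_loss='kl-divergence')
loss_ce = F.softmax_cross_entropy(x, t, soft_target_loss='cross-entropy')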

@beam2d (Member) commented Jul 16, 2019

Thanks for the fix!
CI, test this please.

@chainer-ci (Member)

Jenkins CI test (for commit 1ab2632, target branch master) failed with status FAILURE.

if self.soft_target_loss == 'kl-divergence':
ret = xp.sum(t * (xp.log(t + self.eps) - log_y), axis=1)
else:
ret = -xp.sum(t * log_y), axis=1)
Member

It looks like there is a syntax error here.

Contributor Author

Ah, sorry, that was careless of me. I'll fix it soon.

@beam2d (Member) commented Jul 29, 2019

Jenkins, test this please.

@chainer-ci (Member)

Jenkins CI test (for commit f847214, target branch master) failed with status FAILURE.

@beam2d (Member) commented Jul 29, 2019

It looks like the test is still failing. Could you check it?

@anaruse (Contributor Author) commented Jul 30, 2019

The CI test fails at TestSoftTargetExpectNearZero:

class TestSoftTargetExpectNearZero(BaseSoftTarget, unittest.TestCase):

It fails when 'cross-entropy' is used to compute the soft target loss. The test expects the loss value to be zero, but it becomes non-zero with 'cross-entropy'. I think either this test is not appropriate for 'cross-entropy', or the 'cross-entropy' computation is not correct.

What do you think about this?

@anaruse (Contributor Author) commented Oct 9, 2019

Sorry for being very late.
I fixed the issue with the TestSoftTargetExpectNearZero unit test when cross-entropy is selected for the soft target loss by splitting it into two tests, one for kl-divergence and another for cross-entropy (roughly as sketched below). I believe no problems remain. Could you check this again?
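
A rough sketch of the kind of split described here; the class names and tolerances below are hypothetical, not the PR's actual test code, and it assumes a Chainer build that includes this change.

import unittest

import numpy as np

import chainer.functions as F


class SoftTargetNearZeroBase(object):

    soft_target_loss = None  # set by the concrete test classes

    def setUp(self):
        self.x = np.random.randn(4, 3).astype(np.float32)
        self.t = F.softmax(self.x).array  # soft target equals softmax(x)

    def test_loss(self):
        loss = F.softmax_cross_entropy(
            self.x, self.t, soft_target_loss=self.soft_target_loss).array
        if self.soft_target_loss == 'kl-divergence':
            # KL(t || softmax(x)) vanishes when t == softmax(x).
            np.testing.assert_allclose(loss, 0.0, atol=1e-5)
        else:
            # The cross entropy equals the (non-zero) entropy of t here.
            entropy = float(
                -(self.t * np.log(self.t + 1e-8)).sum(axis=1).mean())
            np.testing.assert_allclose(loss, entropy, atol=1e-4)


class TestSoftTargetNearZeroKL(SoftTargetNearZeroBase, unittest.TestCase):
    soft_target_loss = 'kl-divergence'


class TestSoftTargetEntropyCE(SoftTargetNearZeroBase, unittest.TestCase):
    soft_target_loss = 'cross-entropy'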

@beam2d (Member) commented Oct 17, 2019

Jenkins and flexCI, test this please.

@chainer-ci (Member)

Jenkins CI test (for commit 620b55d, target branch master) succeeded!

@toslunar self-requested a review on October 29, 2019 06:30
@toslunar (Member) left a comment

LGTM

@beam2d merged commit 2659ca2 into chainer:master on Oct 31, 2019
@beam2d added the "cat:feature" (Implementation that introduces new interfaces.) label on Oct 31, 2019
@beam2d added this to the v7.0.0 milestone on Oct 31, 2019
@beam2d (Member) commented Oct 31, 2019

Thank you!!!

@anaruse (Contributor Author) commented Oct 31, 2019

Thank you for merging the PR!
